{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6fa6fb7f",
   "metadata": {},
   "source": [
    "# 관련성 체크 추가\n",
    "\n",
    "**절차**\n",
    "\n",
    "1. Naive RAG 수행\n",
    "2. (이번 튜토리얼) 검색된 문서에 대한 관련성 체크(Groundedness Check) 추가\n",
    "\n",
    "**참고**\n",
    "\n",
    "- 이전 튜토리얼에서 확장된 내용이므로, 겹치는 부분이 있을 수 있습니다. 부족한 설명은 이전 튜토리얼을 참고해주세요.\n",
    "\n",
    "![langgraph-add-relevance-check](assets/langgraph-add-relevance-check.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f21c872b",
   "metadata": {},
   "source": [
    "## 환경 설정"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4d2a11aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install -U langchain-teddynote"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "064d5c8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# API 키를 환경변수로 관리하기 위한 설정 파일\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "# API 키 정보 로드\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "562b0043",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LangSmith 추적을 설정합니다. https://smith.langchain.com\n",
    "# !pip install -qU langchain-teddynote\n",
    "from langchain_teddynote import logging\n",
    "\n",
    "# 프로젝트 이름을 입력합니다.\n",
    "logging.langsmith(\"CH17-LangGraph-Structures\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06468c1c",
   "metadata": {},
   "source": [
    "## 기본 PDF 기반 Retrieval Chain 생성\n",
    "\n",
    "여기서는 PDF 문서를 기반으로 Retrieval Chain 을 생성합니다. 가장 단순한 구조의 Retrieval Chain 입니다.\n",
    "\n",
    "단, LangGraph 에서는 Retirever 와 Chain 을 따로 생성합니다. 그래야 각 노드별로 세부 처리를 할 수 있습니다.\n",
    "\n",
    "**참고**\n",
    "\n",
    "- 이전 튜토리얼에서 다룬 내용이므로, 자세한 설명은 생략합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f905df18",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rag.pdf import PDFRetrievalChain\n",
    "\n",
    "# PDF 문서를 로드합니다.\n",
    "pdf = PDFRetrievalChain([\"data/SPRI_AI_Brief_2023년12월호_F.pdf\"]).create_chain()\n",
    "\n",
    "# retriever와 chain을 생성합니다.\n",
    "pdf_retriever = pdf.retriever\n",
    "pdf_chain = pdf.chain"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d047f938",
   "metadata": {},
   "source": [
    "## State 정의\n",
    "\n",
    "`State`: Graph 의 노드와 노드 간 공유하는 상태를 정의합니다.\n",
    "\n",
    "일반적으로 `TypedDict` 형식을 사용합니다."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de58d48d",
   "metadata": {},
   "source": [
    "이번에는 상태(State)에 관련성(relevance) 체크 결과를 추가합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f19a3df5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Annotated, TypedDict\n",
    "from langgraph.graph.message import add_messages\n",
    "\n",
    "\n",
    "# GraphState 상태 정의\n",
    "class GraphState(TypedDict):\n",
    "    question: Annotated[str, \"Question\"]  # 질문\n",
    "    context: Annotated[str, \"Context\"]  # 문서의 검색 결과\n",
    "    answer: Annotated[str, \"Answer\"]  # 답변\n",
    "    messages: Annotated[list, add_messages]  # 메시지(누적되는 list)\n",
    "    relevance: Annotated[str, \"Relevance\"]  # 관련성"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c56d4095",
   "metadata": {},
   "source": [
    "## 노드(Node) 정의\n",
    "\n",
    "- `Nodes`: 각 단계를 처리하는 노드입니다. 보통은 Python 함수로 구현합니다. 입력과 출력이 상태(State) 값입니다.\n",
    "  \n",
    "**참고**  \n",
    "\n",
    "- `State`를 입력으로 받아 정의된 로직을 수행한 후 업데이트된 `State`를 반환합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9ef0c055",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_openai import ChatOpenAI\n",
    "from langchain_teddynote.evaluator import GroundednessChecker\n",
    "from langchain_teddynote.messages import messages_to_history\n",
    "from rag.utils import format_docs\n",
    "\n",
    "\n",
    "# 문서 검색 노드\n",
    "def retrieve_document(state: GraphState) -> GraphState:\n",
    "    # 질문을 상태에서 가져옵니다.\n",
    "    latest_question = state[\"question\"]\n",
    "\n",
    "    # 문서에서 검색하여 관련성 있는 문서를 찾습니다.\n",
    "    retrieved_docs = pdf_retriever.invoke(latest_question)\n",
    "\n",
    "    # 검색된 문서를 형식화합니다.(프롬프트 입력으로 넣어주기 위함)\n",
    "    retrieved_docs = format_docs(retrieved_docs)\n",
    "\n",
    "    # 검색된 문서를 context 키에 저장합니다.\n",
    "    return GraphState(context=retrieved_docs)\n",
    "\n",
    "\n",
    "# 답변 생성 노드\n",
    "def llm_answer(state: GraphState) -> GraphState:\n",
    "    # 질문을 상태에서 가져옵니다.\n",
    "    latest_question = state[\"question\"]\n",
    "\n",
    "    # 검색된 문서를 상태에서 가져옵니다.\n",
    "    context = state[\"context\"]\n",
    "\n",
    "    # 체인을 호출하여 답변을 생성합니다.\n",
    "    response = pdf_chain.invoke(\n",
    "        {\n",
    "            \"question\": latest_question,\n",
    "            \"context\": context,\n",
    "            \"chat_history\": messages_to_history(state[\"messages\"]),\n",
    "        }\n",
    "    )\n",
    "\n",
    "    # 생성된 답변, (유저의 질문, 답변) 메시지를 상태에 저장합니다.\n",
    "    return {\n",
    "        \"answer\": response,\n",
    "        \"messages\": [(\"user\", latest_question), (\"assistant\", response)],\n",
    "    }\n",
    "\n",
    "\n",
    "# 관련성 체크 노드\n",
    "def relevance_check(state: GraphState) -> GraphState:\n",
    "    # 관련성 평가기를 생성합니다.\n",
    "    question_retrieval_relevant = GroundednessChecker(\n",
    "        llm=ChatOpenAI(model=\"gpt-4o-mini\", temperature=0), target=\"question-retrieval\"\n",
    "    ).create()\n",
    "\n",
    "    # 관련성 체크를 실행(\"yes\" or \"no\")\n",
    "    response = question_retrieval_relevant.invoke(\n",
    "        {\"question\": state[\"question\"], \"context\": state[\"context\"]}\n",
    "    )\n",
    "\n",
    "    print(\"==== [RELEVANCE CHECK] ====\")\n",
    "    print(response.score)\n",
    "\n",
    "    # 참고: 여기서의 관련성 평가기는 각자의 Prompt 를 사용하여 수정할 수 있습니다. 여러분들의 Groundedness Check 를 만들어 사용해 보세요!\n",
    "    return {\"relevance\": response.score}\n",
    "\n",
    "\n",
    "def is_relevant(state: GraphState) -> GraphState:\n",
    "    if state[\"relevance\"] == \"yes\":\n",
    "        return \"relevant\"\n",
    "    else:\n",
    "        return \"not relevant\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3f7785d",
   "metadata": {},
   "source": [
    "## Edges\n",
    "\n",
    "- `Edges`: 현재 `State`를 기반으로 다음에 실행할 `Node`를 결정하는 Python 함수.\n",
    "\n",
    "일반 엣지, 조건부 엣지 등이 있습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a6015807",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.graph import END, StateGraph\n",
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "\n",
    "# 그래프 정의\n",
    "workflow = StateGraph(GraphState)\n",
    "\n",
    "# 노드 추가\n",
    "workflow.add_node(\"retrieve\", retrieve_document)\n",
    "# 관련성 체크 노드 추가\n",
    "workflow.add_node(\"relevance_check\", relevance_check)\n",
    "workflow.add_node(\"llm_answer\", llm_answer)\n",
    "\n",
    "# 엣지 추가\n",
    "workflow.add_edge(\"retrieve\", \"relevance_check\")  # 검색 -> 관련성 체크\n",
    "\n",
    "\n",
    "# 조건부 엣지를 추가합니다.\n",
    "workflow.add_conditional_edges(\n",
    "    \"relevance_check\",  # 관련성 체크 노드에서 나온 결과를 is_relevant 함수에 전달합니다.\n",
    "    is_relevant,\n",
    "    {\n",
    "        \"relevant\": \"llm_answer\",  # 관련성이 있으면 답변을 생성합니다.\n",
    "        \"not relevant\": \"retrieve\",  # 관련성이 없으면 다시 검색합니다.\n",
    "    },\n",
    ")\n",
    "\n",
    "workflow.add_edge(\"llm_answer\", END)\n",
    "\n",
    "# 그래프 진입점 설정\n",
    "workflow.set_entry_point(\"retrieve\")\n",
    "\n",
    "# 체크포인터 설정\n",
    "memory = MemorySaver()\n",
    "\n",
    "# 그래프 컴파일\n",
    "app = workflow.compile(checkpointer=memory)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9a15c32",
   "metadata": {},
   "source": [
    "컴파일한 그래프를 시각화 합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e09251d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_teddynote.graphs import visualize_graph\n",
    "\n",
    "visualize_graph(app)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "110daa16",
   "metadata": {},
   "source": [
    "## 그래프 실행\n",
    "\n",
    "- `config` 파라미터는 그래프 실행 시 필요한 설정 정보를 전달합니다.\n",
    "- `recursion_limit`: 그래프 실행 시 재귀 최대 횟수를 설정합니다.\n",
    "- `inputs`: 그래프 실행 시 필요한 입력 정보를 전달합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2698eaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.runnables import RunnableConfig\n",
    "from langchain_teddynote.messages import stream_graph, invoke_graph, random_uuid\n",
    "\n",
    "# config 설정(재귀 최대 횟수, thread_id)\n",
    "config = RunnableConfig(recursion_limit=20, configurable={\"thread_id\": random_uuid()})\n",
    "\n",
    "# 질문 입력\n",
    "inputs = GraphState(question=\"앤스로픽에 투자한 기업과 투자금액을 알려주세요.\")\n",
    "\n",
    "# 그래프 실행\n",
    "invoke_graph(app, inputs, config, [\"relevance_check\", \"llm_answer\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdbd17ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 그래프 스트리밍 출력\n",
    "stream_graph(app, inputs, config, [\"relevance_check\", \"llm_answer\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a16ba031",
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = app.get_state(config).values\n",
    "\n",
    "print(f'Question: {outputs[\"question\"]}')\n",
    "print(\"===\" * 20)\n",
    "print(f'Answer:\\n{outputs[\"answer\"]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b734deb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(outputs[\"relevance\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4af077a6",
   "metadata": {},
   "source": [
    "하지만, 검색 결과의 `relevance_check` 가 실패할 경우, 반복하여 동일한 Query 가 다시 retrieve 노드로 들어가는 상황이 발생합니다.\n",
    "\n",
    "반복하여 동일한 Query 가 다시 retrieve 노드로 들어가면, 동일한 검색 결과로 이어지기 때문에, 결국 재귀 상태에 빠지게 됩니다.\n",
    "\n",
    "혹시 모를 재귀 상태를 방지하기 위해, 재귀 최대 횟수(`recursion_limit`)를 설정합니다. 그리고, 에러 처리를 위하여 `GraphRecursionError` 를 처리합니다.\n",
    "\n",
    "다음 튜토리얼에서는 이와 같이 재귀상태로 빠지는 문제를 해결하는 방법을 다루겠습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12129e2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.errors import GraphRecursionError\n",
    "from langchain_core.runnables import RunnableConfig\n",
    "\n",
    "# config 설정(재귀 최대 횟수, thread_id)\n",
    "config = RunnableConfig(recursion_limit=10, configurable={\"thread_id\": random_uuid()})\n",
    "\n",
    "# 질문 입력\n",
    "inputs = GraphState(question=\"테디노트의 랭체인 튜토리얼에 대한 정보를 알려주세요.\")\n",
    "\n",
    "try:\n",
    "    # 그래프 실행\n",
    "    stream_graph(app, inputs, config, [\"retrieve\", \"relevance_check\", \"llm_answer\"])\n",
    "except GraphRecursionError as recursion_error:\n",
    "    print(f\"GraphRecursionError: {recursion_error}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py-test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
